import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.integrate import cumulative_trapezoid
from scipy.interpolate import interp1d

# --- CONFIGURATION ---
# Input table read by main() and the path the generated figure is saved to.
INPUT_FILE = "produced/refined_actual_data_1701.csv"
OUTPUT_PLOT = "plots/hubble_residuals_bridge.png"

# Speed of light in km/s.
C_LIGHT = 299_792.458
# Distance scale K of the refractive velocity law. main() also uses it as the
# Hubble distance of the linear reference: H0_REF = C_LIGHT / K_EXP (~71).
# NOTE(review): units presumably Mpc (distance moduli use the +25 offset) -- confirm.
K_EXP = 4224.0

def get_mu_refractive_smooth(z_grid, K):
    """Return the refractive-model distance modulus at each redshift in *z_grid*.

    Model (all distances in the units of ``K``):

    * recession speed   beta(r) = v/c = 1 - exp(-r / K)
    * refractive index  n(r)    = 1 + beta(r)**2
    * optical distance  d_opt   = integral_0^r n(r') dr'
    * redshift          z       = beta / (1 - beta)      (so r = K * ln(1 + z))
    * luminosity dist.  d_L     = d_opt * (1 + z)
    * distance modulus  mu      = 5 * log10(d_L) + 25

    Inverting z(r) gives beta = z / (1 + z), and the integral has the closed
    form d_opt = K * (2 ln(1 + z) - beta - beta**2 / 2). Evaluating it directly
    (rather than integrating on a grid sized from max(z_grid) and
    interpolating) makes every output depend only on its own redshift and
    removes grid-resolution error at large z.

    The ``+ 25`` offset is the convention for d_L in megaparsecs.
    NOTE(review): assumes ``K`` is supplied in Mpc -- confirm with callers.

    Parameters
    ----------
    z_grid : array_like of float
        Redshifts, any shape. Lists and integer arrays are accepted.
    K : float
        Positive distance scale of the velocity law.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``z_grid``. Entries with
        ``z <= 0`` (or NaN) are NaN, because log10(d_L) is undefined there.

    Raises
    ------
    ValueError
        If ``K`` is not a positive number.
    """
    if not K > 0:
        raise ValueError(f"K must be positive, got {K!r}")

    z = np.asarray(z_grid, dtype=float)
    # Out-of-domain entries (z <= -1, z == 0, NaN) produce inf/NaN here;
    # they are masked below, so silence the floating-point warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        # Invert z = beta / (1 - beta).
        beta = z / (1.0 + z)
        # Closed form of integral_0^r [1 + (1 - exp(-s/K))**2] ds at r = K*ln(1+z);
        # log1p keeps ln(1 + z) accurate for small z.
        d_opt = K * (2.0 * np.log1p(z) - beta - 0.5 * beta**2)
        d_lum = d_opt * (1.0 + z)
        mu = 5.0 * np.log10(d_lum) + 25.0
    # NaN (not a 0-mag sentinel) marks redshifts where mu is undefined.
    return np.where(z > 0.0, mu, np.nan)

def main():
    """Build the two-panel Hubble-residual "bridge" figure and save it.

    Reads ``INPUT_FILE`` (needs ``z_obs`` and ``mu_obs`` columns), keeps rows
    with ``z_obs > 0.001`` and writes ``OUTPUT_PLOT``:

    * top panel -- residuals against a linear Hubble law with
      ``H0 = C_LIGHT / K_EXP``, overlaid with the refractive-model prediction
      for the same difference;
    * bottom panel -- residuals after subtracting the refractive model.

    Input problems (missing file, missing columns, no rows left after the
    redshift cut) are reported with a printed message and an early return
    rather than an exception.
    """
    print("Generating Bridge Plot...")
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found. Run data generation first.")
        return

    df = pd.read_csv(INPUT_FILE)
    missing = {'z_obs', 'mu_obs'} - set(df.columns)
    if missing:
        print(f"Error: {INPUT_FILE} lacks column(s): {', '.join(sorted(missing))}")
        return

    # The linear reference below takes log10(z), so z must stay strictly positive.
    df = df[df['z_obs'] > 0.001].copy()
    if df.empty:
        # An empty selection would leave the redshift range of the theory curve undefined.
        print(f"Error: no rows with z_obs > 0.001 in {INPUT_FILE}.")
        return

    z_obs = df['z_obs'].to_numpy(dtype=float)
    mu_obs = df['mu_obs'].to_numpy(dtype=float)

    # Linear Hubble law d = c*z/H0; with H0 = c/K this reduces to d = K*z.
    H0_REF = C_LIGHT / K_EXP

    def mu_linear_of(z):
        """Distance modulus of the linear Hubble law at redshift(s) *z*."""
        return 5.0 * np.log10(C_LIGHT * z / H0_REF) + 25.0

    # Standard residuals (signal): observations minus the linear law.
    signal_data = mu_obs - mu_linear_of(z_obs)

    # Theory curve on a log-spaced grid spanning the selected data.
    z_smooth = np.logspace(np.log10(0.001), np.log10(z_obs.max()), 200)
    signal_theory = get_mu_refractive_smooth(z_smooth, K_EXP) - mu_linear_of(z_smooth)

    # Corrected residuals: observations minus the refractive model.
    residuals_corrected = mu_obs - get_mu_refractive_smooth(z_obs, K_EXP)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10), sharex=True)
    fig.subplots_adjust(hspace=0.05)

    ax1.scatter(z_obs, signal_data, s=15, color='gray', alpha=0.3, label='Standard Residuals')
    ax1.plot(z_smooth, signal_theory, color='blue', linewidth=3, label='Refractive Prediction')
    ax1.set_ylabel(r'$\mu_{obs} - \mu_{linear}$ (mag)', fontsize=14)
    ax1.legend(loc='upper left')
    ax1.grid(True, alpha=0.3)
    ax1.axhline(0, color='k', lw=1, linestyle=':')

    ax2.scatter(z_obs, residuals_corrected, s=15, color='royalblue', alpha=0.3, label='Corrected Residuals')
    ax2.axhline(0, color='k', linewidth=2, linestyle='-')
    ax2.set_xlabel('Redshift $z$', fontsize=14)
    ax2.set_ylabel(r'Residuals (mag)', fontsize=14)
    # sharex=True: the log x-scale set here applies to ax1 as well.
    ax2.set_xscale('log')
    ax2.legend(loc='lower left')
    ax2.grid(True, alpha=0.3)
    ax2.set_ylim(-0.5, 0.5)

    # Create the directory OUTPUT_PLOT actually points into, not a hard-coded one.
    out_dir = os.path.dirname(OUTPUT_PLOT)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(OUTPUT_PLOT, dpi=300, bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps a reference to it.
    plt.close(fig)
    print(f"Saved {OUTPUT_PLOT}")

if __name__ == "__main__":
    # Run the plot only when executed as a script; importing the module does not call main().
    main()